{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5. Dependency Parsing "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I recommend you take a look at these materials first."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture6.pdf\n",
    "* http://cs.stanford.edu/people/danqi/papers/emnlp2014.pdf\n",
    "* https://github.com/rguthrie3/DeepDependencyParsingProblemSet/tree/master/data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import nltk\n",
    "import random\n",
    "import numpy as np\n",
    "from collections import Counter, OrderedDict\n",
    "import nltk\n",
    "from nltk.tree import Tree\n",
    "import os\n",
    "from IPython.display import Image, display\n",
    "from nltk.draw import TreeWidget\n",
    "from nltk.draw.util import CanvasFrame\n",
    "flatten = lambda l: [item for sublist in l for item in sublist]\n",
    "random.seed(1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "USE_CUDA = torch.cuda.is_available()\n",
    "gpus = [0]\n",
    "torch.cuda.set_device(gpus[0])\n",
    "\n",
    "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def getBatch(batch_size, train_data):\n",
    "    random.shuffle(train_data)\n",
    "    sindex = 0\n",
    "    eindex = batch_size\n",
    "    while eindex < len(train_data):\n",
    "        batch = train_data[sindex: eindex]\n",
    "        temp = eindex\n",
    "        eindex = eindex + batch_size\n",
    "        sindex = temp\n",
    "        yield batch\n",
    "    \n",
    "    if eindex >= len(train_data):\n",
    "        batch = train_data[sindex:]\n",
    "        yield batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def prepare_sequence(seq, to_index):\n",
    "    \"\"\"Map every symbol of ``seq`` to its index and wrap the result in a LongTensor Variable.\n",
    "\n",
    "    Symbols that are missing from ``to_index`` fall back to the index of \"<unk>\".\n",
    "    \"\"\"\n",
    "    idxs = []\n",
    "    for symbol in seq:\n",
    "        index = to_index.get(symbol)\n",
    "        if index is None:\n",
    "            # Resolved lazily: the mapping only needs an \"<unk>\" entry when a symbol is actually unknown.\n",
    "            index = to_index[\"<unk>\"]\n",
    "        idxs.append(index)\n",
    "    return Variable(LongTensor(idxs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Borrowed from https://stackoverflow.com/questions/31779707/how-do-you-make-nltk-draw-trees-that-are-inline-in-ipython-jupyter\n",
    "\n",
    "def draw_nltk_tree(tree):\n",
    "    \"\"\"Render an nltk tree inline in the notebook.\n",
    "\n",
    "    The tree is drawn on an nltk canvas, written out as PostScript, converted\n",
    "    to PNG by the external ``convert`` command and then displayed.\n",
    "\n",
    "    Args:\n",
    "        tree: the tree object handed to ``TreeWidget``.\n",
    "\n",
    "    Raises:\n",
    "        subprocess.CalledProcessError: if ``convert`` exits with a non-zero status.\n",
    "    \"\"\"\n",
    "    import shutil\n",
    "    import subprocess\n",
    "    import tempfile\n",
    "\n",
    "    cf = CanvasFrame()\n",
    "    tc = TreeWidget(cf.canvas(), tree)\n",
    "    tc['node_font'] = 'arial 15 bold'\n",
    "    tc['leaf_font'] = 'arial 15'\n",
    "    tc['node_color'] = '#005990'\n",
    "    tc['leaf_color'] = '#3F8F57'\n",
    "    tc['line_color'] = '#175252'\n",
    "    cf.add_widget(tc, 50, 50)\n",
    "\n",
    "    # A private scratch directory keeps the intermediate files out of the working\n",
    "    # directory and is removed even when the conversion fails.\n",
    "    workdir = tempfile.mkdtemp()\n",
    "    ps_path = os.path.join(workdir, 'tmp_tree_output.ps')\n",
    "    png_path = os.path.join(workdir, 'tmp_tree_output.png')\n",
    "    try:\n",
    "        cf.print_to_file(ps_path)\n",
    "        cf.destroy()\n",
    "        # Argument list instead of a shell string; a failed conversion raises instead of passing silently.\n",
    "        subprocess.check_call(['convert', ps_path, png_path])\n",
    "        display(Image(filename=png_path))\n",
    "    finally:\n",
    "        shutil.rmtree(workdir, ignore_errors=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transition State Class "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It tracks the transition state (current stack and buffer) and extracts its features for the neural dependency parser."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/05.transition-based-parse.png\">\n",
    "<center>borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture6.pdf</center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class TransitionState(object):\n",
    "    \n",
    "    def __init__(self, tagged_sent):\n",
    "        self.root = ('ROOT', '<root>', -1)\n",
    "        self.stack = [self.root]\n",
    "        self.buffer = [(s[0], s[1], i) for i, s in enumerate(tagged_sent)]\n",
    "        self.address = [s[0] for s in tagged_sent] + [self.root[0]]\n",
    "        self.arcs = []\n",
    "        self.terminal=False\n",
    "        \n",
    "    def __str__(self):\n",
    "        return 'stack : %s \\nbuffer : %s' % (str([s[0] for s in self.stack]), str([b[0] for b in self.buffer]))\n",
    "    \n",
    "    def shift(self):\n",
    "        \n",
    "        if len(self.buffer) >= 1:\n",
    "            self.stack.append(self.buffer.pop(0))\n",
    "        else:\n",
    "            print(\"Empty buffer\")\n",
    "            \n",
    "    def left_arc(self, relation=None):\n",
    "        \n",
    "        if len(self.stack) >= 2:\n",
    "            arc = {}\n",
    "            s2 = self.stack[-2]\n",
    "            s1 = self.stack[-1]\n",
    "            arc['graph_id'] = len(self.arcs)\n",
    "            arc['form'] = s1[0]\n",
    "            arc['addr'] = s1[2]\n",
    "            arc['head'] = s2[2]\n",
    "            arc['pos'] = s1[1]\n",
    "            if relation:\n",
    "                arc['relation'] = relation\n",
    "            self.arcs.append(arc)\n",
    "            self.stack.pop(-2)\n",
    "            \n",
    "        elif self.stack == [self.root]:\n",
    "            print(\"Element Lacking\")\n",
    "    \n",
    "    def right_arc(self, relation=None):\n",
    "        \n",
    "        if len(self.stack) >= 2:\n",
    "            arc = {}\n",
    "            s2 = self.stack[-2]\n",
    "            s1 = self.stack[-1]\n",
    "            arc['graph_id'] = len(self.arcs)\n",
    "            arc['form'] = s2[0]\n",
    "            arc['addr'] = s2[2]\n",
    "            arc['head'] = s1[2]\n",
    "            arc['pos'] = s2[1]\n",
    "            if relation:\n",
    "                arc['relation'] = relation\n",
    "            self.arcs.append(arc)\n",
    "            self.stack.pop(-1)\n",
    "            \n",
    "        elif self.stack == [self.root]:\n",
    "            print(\"Element Lacking\")\n",
    "    \n",
    "    def get_left_most(self, index):\n",
    "        left=['<NULL>', '<NULL>', None]\n",
    "        \n",
    "        if index == None: \n",
    "            return left\n",
    "        for arc in self.arcs:\n",
    "            if arc['head'] == index:\n",
    "                left = [arc['form'], arc['pos'], arc['addr']]\n",
    "                break\n",
    "        return left\n",
    "    \n",
    "    def get_right_most(self, index):\n",
    "        right=['<NULL>', '<NULL>', None]\n",
    "        \n",
    "        if index == None: \n",
    "            return right\n",
    "        for arc in reversed(self.arcs):\n",
    "            if arc['head'] == index:\n",
    "                right=[arc['form'], arc['pos'], arc['addr']]\n",
    "                break\n",
    "        return right\n",
    "    \n",
    "    def is_done(self):\n",
    "        return len(self.buffer) == 0 and self.stack == [self.root]\n",
    "    \n",
    "    def to_tree_string(self):\n",
    "        if self.is_done() == False: \n",
    "            return None\n",
    "        ingredient = []\n",
    "        for arc in self.arcs:\n",
    "            ingredient.append([arc['form'], self.address[arc['head']]])\n",
    "        ingredient = ingredient[-1:] + ingredient[:-1]\n",
    "        return self._make_tree(ingredient, 0)\n",
    "    \n",
    "    def _make_tree(self, ingredient, i, new=True):\n",
    "    \n",
    "        if new:\n",
    "            treestr = \"(\"\n",
    "            treestr += ingredient[i][0]\n",
    "            treestr += \" \"\n",
    "        else:\n",
    "            treestr = \"\"\n",
    "        ingredient[i][0] = \"CHECK\"\n",
    "\n",
    "        parents,_ = list(zip(*ingredient))\n",
    "\n",
    "        if ingredient[i][1] not in parents:\n",
    "            treestr += ingredient[i][1]\n",
    "            return treestr\n",
    "\n",
    "        else:\n",
    "            treestr += \"(\"\n",
    "            treestr += ingredient[i][1]\n",
    "            treestr += \" \"\n",
    "            for node_i, node in enumerate(parents):\n",
    "                if node == ingredient[i][1]:\n",
    "                    treestr += self._make_tree(ingredient, node_i, False)\n",
    "                    treestr += \" \"\n",
    "\n",
    "            treestr = treestr.strip()\n",
    "            treestr += \")\"\n",
    "        if new:\n",
    "            treestr += \")\"\n",
    "        return treestr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the example of transition-based dependency parsing from the paper. The model's goal is to predict the correct transition for the parser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stack : ['ROOT'] \n",
      "buffer : ['He', 'has', 'good', 'control', '.']\n",
      "stack : ['ROOT', 'He', 'has'] \n",
      "buffer : ['good', 'control', '.']\n",
      "stack : ['ROOT', 'has'] \n",
      "buffer : ['good', 'control', '.']\n",
      "[{'pos': 'VBZ', 'graph_id': 0, 'addr': 1, 'head': 0, 'form': 'has'}]\n",
      "stack : ['ROOT', 'has', 'good', 'control'] \n",
      "buffer : ['.']\n",
      "stack : ['ROOT', 'has', 'control'] \n",
      "buffer : ['.']\n",
      "stack : ['ROOT', 'has'] \n",
      "buffer : ['.']\n",
      "stack : ['ROOT', 'has'] \n",
      "buffer : []\n",
      "stack : ['ROOT'] \n",
      "buffer : []\n",
      "[{'pos': 'VBZ', 'graph_id': 0, 'addr': 1, 'head': 0, 'form': 'has'}, {'pos': 'NN', 'graph_id': 1, 'addr': 3, 'head': 2, 'form': 'control'}, {'pos': 'VBZ', 'graph_id': 2, 'addr': 1, 'head': 3, 'form': 'has'}, {'pos': 'VBZ', 'graph_id': 3, 'addr': 1, 'head': 4, 'form': 'has'}, {'pos': '<root>', 'graph_id': 4, 'addr': -1, 'head': 1, 'form': 'ROOT'}]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Replay, step by step, the transition sequence that parses the example sentence.\n",
    "state = TransitionState(nltk.pos_tag(\"He has good control .\".split()))\n",
    "print(state)\n",
    "state.shift()\n",
    "state.shift()\n",
    "print(state)\n",
    "state.left_arc()  # attaches 'He' to 'has' and drops 'He' from the stack\n",
    "print(state)\n",
    "print(state.arcs)\n",
    "state.shift()\n",
    "state.shift()\n",
    "print(state)\n",
    "state.left_arc()  # attaches 'good' to 'control'\n",
    "print(state)\n",
    "state.right_arc()  # attaches 'control' to 'has'\n",
    "print(state)\n",
    "state.shift()\n",
    "state.right_arc()  # attaches '.' to 'has'\n",
    "print(state)\n",
    "state.right_arc()  # attaches 'has' to ROOT, which finishes the parse\n",
    "print(state)\n",
    "print(state.arcs)\n",
    "state.is_done()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'(ROOT (has He (control good) .))'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state.to_tree_string()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEMCAMAAAAf2ZYHAAAJJGlDQ1BpY2MAAHjalZVnUJNZF8fv\n8zzphUASQodQQ5EqJYCUEFoo0quoQOidUEVsiLgCK4qINEUQUUDBVSmyVkSxsCgoYkE3yCKgrBtX\nERWUF/Sd0Xnf2Q/7n7n3/OY/Z+4995wPFwCCOFgSvLQnJqULvJ3smIFBwUzwg8L4aSkcT0838I96\nPwyg5XhvBfj3IkREpvGX4sLSyuWnCNIBgLKXWDMrPWWZDy8xPTz+K59dZsFSgUt8Y5mjv/Ho15xv\nLPqa4+vNXXoVCgAcKfoHDv+B/3vvslQ4gvTYqMhspk9yVHpWmCCSmbbcCR6Xy/QUJEfFJkT+UPC/\nSv4HpUdmpy9HbnLKBkFsdEw68/8ONTIwNATfZ/HW62uPIUb//85nWd+95HoA2LMAIHu+e+GVAHTu\nAED68XdPbamvlHwAOu7wMwSZ3zzU8oYGBEABdCADFIEq0AS6wAiYAUtgCxyAC/AAviAIrAN8EAMS\ngQBkgVywDRSAIrAH7AdVoBY0gCbQCk6DTnAeXAHXwW1wFwyDJ0AIJsArIALvwTwEQViIDNEgGUgJ\nUod0ICOIDVlDDpAb5A0FQaFQNJQEZUC50HaoCCqFqqA6qAn6BToHXYFuQoPQI2gMmob+hj7BCEyC\n6bACrAHrw2yYA7vCvvBaOBpOhXPgfHg3XAHXwyfgDvgKfBsehoXwK3gWAQgRYSDKiC7CRriIBxKM\nRCECZDNSiJQj9Ugr0o30IfcQITKDfERhUDQUE6WLskQ5o/xQfFQqajOqGFWFOo7qQPWi7qHGUCLU\nFzQZLY/WQVugeehAdDQ6C12ALkc3otvR19DD6An0ewwGw8CwMGYYZ0wQJg6zEVOMOYhpw1zGDGLG\nMbNYLFYGq4O1wnpgw7Dp2AJsJfYE9hJ2CDuB/YAj4pRwRjhHXDAuCZeHK8c14y7ihnCTuHm8OF4d\nb4H3wEfgN+BL8A34bvwd/AR+niBBYBGsCL6EOMI2QgWhlXCNMEp4SyQSVYjmRC9iLHErsYJ4iniD\nOEb8SKKStElcUggpg7SbdIx0mfSI9JZMJmuQbcnB5HTybnIT+Sr5GfmDGE1MT4wnFiG2RaxarENs\nSOw1BU9Rp3Ao6yg5lHLKGcodyow4XlxDnCseJr5ZvFr8nPiI+KwETcJQwkMiUaJYolnipsQUFUvV\noDpQI6j51CPUq9RxGkJTpXFpfNp2WgPtGm2CjqGz6Dx6HL2IfpI+QBdJUiWNJf0lsyWrJS9IChkI\nQ4PBYyQwShinGQ8Yn6QUpDhSkVK7pFqlhqTmpOWkbaUjpQul26SHpT/JMGUcZOJl9sp0yjyVRclq\ny3rJZskekr0mOyNHl7OU48sVyp2WeywPy2vLe8tvlD8i3y8/q6Co4KSQolCpcFVhRpGhaKsYp1im\neFFxWommZK0Uq1SmdEnpJVOSyWEmMCuYvUyRsryys3KGcp3ygPK8CkvFTyVPpU3lqSpBla0apVqm\n2qMqUlNSc1fLVWtRe6yOV2erx6gfUO9Tn9NgaQRo7NTo1JhiSbN4rBxWC2tUk6xpo5mqWa95Xwuj\nxdaK1zqodVcb1jbRjtGu1r6jA+uY6sTqHNQZXIFeYb4iaUX9ihFdki5HN1O3RXdMj6Hnppen16n3\nWl9NP1h/r36f/hcDE4MEgwaDJ4ZUQxfDPMNuw7+NtI34RtVG91eSVzqu3LKya+UbYx3jSONDxg9N\naCbuJjtNekw+m5qZCkxbTafN1MxCzWrMRth0tie7mH3DHG1uZ77F/Lz5RwtTi3SL0xZ/Wepaxls2\nW06tYq2KXNWwatxKxSrMqs5KaM20DrU+bC20UbYJs6m3eW6rahth22g7ydHixHFOcF7bGdgJ7Nrt\n5rgW3E3cy/aIvZN9of2AA9XBz6HK4Z
mjimO0Y4ujyMnEaaPTZWe0s6vzXucRngKPz2viiVzMXDa5\n9LqSXH1cq1yfu2m7Cdy63WF3F/d97qOr1Vcnre70AB48j30eTz1Znqmev3phvDy9qr1eeBt653r3\n+dB81vs0+7z3tfMt8X3ip+mX4dfjT/EP8W/ynwuwDygNEAbqB24KvB0kGxQb1BWMDfYPbgyeXeOw\nZv+aiRCTkIKQB2tZa7PX3lwnuy5h3YX1lPVh68+EokMDQptDF8I8wurDZsN54TXhIj6Xf4D/KsI2\noixiOtIqsjRyMsoqqjRqKtoqel/0dIxNTHnMTCw3tir2TZxzXG3cXLxH/LH4xYSAhLZEXGJo4rkk\nalJ8Um+yYnJ28mCKTkpBijDVInV/qkjgKmhMg9LWpnWl05c+xf4MzYwdGWOZ1pnVmR+y/LPOZEtk\nJ2X3b9DesGvDZI5jztGNqI38jT25yrnbcsc2cTbVbYY2h2/u2aK6JX/LxFanrce3EbbFb/stzyCv\nNO/d9oDt3fkK+Vvzx3c47WgpECsQFIzstNxZ+xPqp9ifBnat3FW560thROGtIoOi8qKFYn7xrZ8N\nf674eXF31O6BEtOSQ3swe5L2PNhrs/d4qURpTun4Pvd9HWXMssKyd/vX779Zblxee4BwIOOAsMKt\noqtSrXJP5UJVTNVwtV11W418za6auYMRB4cO2R5qrVWoLar9dDj28MM6p7qOeo368iOYI5lHXjT4\nN/QdZR9tapRtLGr8fCzpmPC49/HeJrOmpmb55pIWuCWjZfpEyIm7J+1PdrXqtta1MdqKToFTGade\n/hL6y4PTrqd7zrDPtJ5VP1vTTmsv7IA6NnSIOmM6hV1BXYPnXM71dFt2t/+q9+ux88rnqy9IXii5\nSLiYf3HxUs6l2cspl2euRF8Z71nf8+Rq4NX7vV69A9dcr9247nj9ah+n79INqxvnb1rcPHeLfavz\ntuntjn6T/vbfTH5rHzAd6Lhjdqfrrvnd7sFVgxeHbIau3LO/d/0+7/7t4dXDgw/8HjwcCRkRPox4\nOPUo4dGbx5mP559sHUWPFj4Vf1r+TP5Z/e9av7cJTYUXxuzH+p/7PH8yzh9/9UfaHwsT+S/IL8on\nlSabpoymzk87Tt99ueblxKuUV/MzBX9K/FnzWvP12b9s/+oXBYom3gjeLP5d/Fbm7bF3xu96Zj1n\nn71PfD8/V/hB5sPxj+yPfZ8CPk3OZy1gFyo+a33u/uL6ZXQxcXHxPy6ikLxyKdSVAAAAIGNIUk0A\nAHomAACAhAAA+gAAAIDoAAB1MAAA6mAAADqYAAAXcJy6UTwAAACTUExURf///wBZkABZkABZkABZ\nkABZkABZkABZkABZkABZkABZkABZkABZkABZkABZkBdSUhdSUhdSUhdSUhdSUhdSUhdSUhdSUhdS\nUhdSUhdSUhdSUhdSUhdSUhdSUhdSUj+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+P\nVz+PVz+PVwBZkBdSUj+PV////5bcvbMAAAAtdFJOUwDumRGI3cxEuyIzZlV3qiJEu4iZfhF33cxg\n7jOqZlUiRIju3ZnMZrt3ETNVqt3J4RAAAAABYktHRACIBR1IAAAACXBIWXMAAABIAAAASABGyWs+\nAAAAB3RJTUUH4QsCCx8TiaQ7vAAABQBJREFUeNrt3W1XskoYhuFBEZDMzLblfrIUXzKqkf//7/bM\nAIY9vZg7vFdyHh8QgeUar5FhdM2gUgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADwK2nH\na7XNetvvaN3x7aryg2I11AVfuqgnSOsgiiKzVKrraW1Xva5qB3a7W43z3VF0Jl3UE6R1aJaxfYh0\n0DV1EOhI+dqLla2DqHIQftw2/bhdZNzT2sTu2hnT6rRfD8KPe215TNTlprCMu3gk/ZoUl1Rz1SX9\n43
PBBvrctPJa9+yWSsvTo+Wplws2tJ0b1S+vui171Q3dVbf1ehB+Xh5sy3Zu2p28xxm07dngepyd\nduUg/Lg82Lbn2h77bSvIv22dB9svXqQPAEB9NtIFaLTNhXQJmmwzkC5Bk5G+JNKXRPqSLgfSJWgy\n0pdE+pIuh9IlaDLSl0T6kkhf0hXpCxqSPgCcviiSLkGTlcM4IYH0JWnP93QUK2VHtPXbSsV9O5xN\nulgNoT3PD3RfnelOeG7HdXb0WdjSTBg6CjtYs2uanzjs5s2QHVHeZgDncbh23yy6fTefwg5s3o6p\nRd3K9CM7c8I9OevbMeXS5WqGMn3bArlZRN2um04nXa5mKNPvmObGfOZ7oQ56oc9n/zjK9OOO7sS+\n9lQvKCexAwBO2uCfgXQRmmp0fbPZbG6uR9IFaaDB1WZ8PRr/e33z5+pWujDNcjEsMp8M83q44wQ4\nlvvJNu7hxC63lYGa7SZ9Py5XbI1Il+3UmZAvq63M7eukRVstQybR1eZi+OevfDfVBuf+0tSNdClP\n09270Y7vd56+V0H4vz5MdfLXUEJbS/d7vCT2MzKBTj4INO/07DJVdcMJ8DNurz7Lctvp2XU3/rC+\nsLfRVzHefnSnhs8rDV/bJ8HNh9+ybM1dDaTfwy81ut4rvPGnZ4apPn6FO8Bwz18vJ5/Pn7A/hDLD\n4tsGg/2O+7qG9n0lAI01nVUfDsYEi4MkWfVhL/E7STOw7SAHpH/+TtKkf5Bq+sk8WyTlDjterd81\nK2f2zss927i4sWzKd+OY+1HsRa87Sf8gSTazlib9ZDVVs/lDvr2nPb9l73XtmxXfDti3g8btH36E\nWvu+edYJ/OpO6TfyKyVZQSkXfLLOt7v77Lf8bv6vH74OtnMoymGd+Y35tzul38ivVGl5ttVglTcX\nz/93InTjx1U5c6VYebsT31VNf1bZTvrHUEl/nlS2u5bHj8L8Xz+2jctu+m924rsq6SfzmZquHvPt\n+VXXy6+6LXsmVNI/t5Vij9rdWZgl3y1EY73pcc6TtNix0+Psh5WPvOnkeGXcOzt3XxI1YdYKADQH\n/z4hifQlkb4k0pdE+pJIXxLpSyJ9SaQvifQlkb4k0pdE+pJIXxLpA0CTrZ+kS9BkO6MPcWSkX690\nnS2es0Spx0WWvZiGZvqSuUkWZsf8gfTrtVylJueZesweVPoyT9PFeqpmpjqW86d0Sfr1Wjwo9WRC\nXr0ou/L8nNkL7ctKLZbmPCD9erl8zSJzI5OzJB8ja5ZuA+nXi/QlzWl5BJVX3Qdz1Z2uF+ZCsErd\nk+V8miakXy/b45zZkF2Pc1r0OB+LHuf6Wbp8py/lIy5kuU7VciFdiqZKTUOz4rc0AAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA4Ff6D/sti7BIfp0T\nAAAAJXRFWHRkYXRlOmNyZWF0ZQAyMDE3LTExLTAyVDIwOjMxOjE5KzA5OjAwmfE8pwAAACV0RVh0\nZGF0ZTptb2RpZnkAMjAxNy0xMS0wMlQyMDozMToxOSswOTowMOishBsAAAAjdEVYdHBzOkhpUmVz\nQm91bmRpbmdCb3gAMzgxeDI2OC0xOTAtMTMzZk0dLQAAABx0RVh0cHM6TGV2ZWwAQWRvYmUtMy4w\nIEVQU0YtMy4wCptwu+MAAAAidEVYdHBzOlNwb3RDb2xvci0wAGZvbnQgTGliZXJhdGlvblNhbnP+\nGafGAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "draw_nltk_tree(Tree.fromstring(state.to_tree_string()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data load & Preprocessing "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get features from stack and buffer. <br>\n",
    "1.  The top 3 words on the stack and buffer ($s_1,s_2,s_3,b_1,b_2,b_3$)\n",
    "2. The first and second leftmost / rightmost children of the top two words on the stack: $lc_1(s_i), rc_1(s_i), lc_2(s_i), rc_2(s_i), i = 1, 2$\n",
    "3. <i>The leftmost of leftmost / rightmost of rightmost children of the top two words on the stack: $lc_1(lc_1(s_i)), rc_1(rc_1(s_i)), i = 1, 2$. # I don't use these features</i>\n",
    "4. POS tags for $S^t$\n",
    "5. <i>corresponding arc labels of words excluding those 6 words on the stack/buffer for $S^l$ # I don't use these features</i>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_feat(transition_state, word2index, tag2index, label2index=None):\n",
    "    \"\"\"Build the word and POS-tag feature tensors for one parser configuration.\n",
    "\n",
    "    Ten features are read for each of the two views, in this order: the top\n",
    "    three stack items (s1, s2, s3), the first three buffer items (b1, b2, b3),\n",
    "    then the left-most and right-most child of s1 followed by those of s2.\n",
    "    A missing position, or a symbol that is absent from the corresponding\n",
    "    vocabulary, is represented by '<NULL>'.\n",
    "\n",
    "    Args:\n",
    "        transition_state: the TransitionState to describe.\n",
    "        word2index: mapping from word form to index.\n",
    "        tag2index: mapping from POS tag to index.\n",
    "        label2index: unused; kept so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        A pair ``(word_ids, tag_ids)``; each is the output of\n",
    "        ``prepare_sequence`` reshaped with ``view(1, -1)``.\n",
    "    \"\"\"\n",
    "    NULL = '<NULL>'\n",
    "    stack = transition_state.stack\n",
    "    buf = transition_state.buffer\n",
    "\n",
    "    def known(symbol, vocab):\n",
    "        # Out-of-vocabulary symbols share the '<NULL>' slot with missing positions.\n",
    "        return symbol if symbol in vocab else NULL\n",
    "\n",
    "    def at(items, position, field, vocab):\n",
    "        # items[position][field] when that position exists, else '<NULL>'.\n",
    "        if -len(items) <= position < len(items):\n",
    "            return known(items[position][field], vocab)\n",
    "        return NULL\n",
    "\n",
    "    positions = [(stack, -1), (stack, -2), (stack, -3), (buf, 0), (buf, 1), (buf, 2)]\n",
    "    word_feats = [at(items, position, 0, word2index) for items, position in positions]\n",
    "    tag_feats = [at(items, position, 1, tag2index) for items, position in positions]\n",
    "\n",
    "    # Children of s1, then of s2. They go through the same vocabulary check as\n",
    "    # the other features, so an unseen tag cannot reach prepare_sequence, whose\n",
    "    # fallback needs an '<unk>' entry.\n",
    "    for position in (-1, -2):\n",
    "        index = stack[position][2] if len(stack) >= -position else None\n",
    "        for child in (transition_state.get_left_most(index), transition_state.get_right_most(index)):\n",
    "            word_feats.append(known(child[0], word2index))\n",
    "            tag_feats.append(known(child[1], tag2index))\n",
    "\n",
    "    return prepare_sequence(word_feats, word2index).view(1, -1), prepare_sequence(tag_feats, tag2index).view(1, -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can get data from <a href=\"https://github.com/rguthrie3/DeepDependencyParsingProblemSet\">this repo</a>."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Context managers close the files as soon as their lines have been read.\n",
    "with open('../dataset/dparser/train.txt', 'r') as train_file:\n",
    "    data = train_file.readlines()\n",
    "with open('../dataset/dparser/vocab.txt', 'r') as vocab_file:\n",
    "    vocab = vocab_file.readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Each line reads '<sentence> ||| <transitions>'. str.split() already drops the\n",
    "# line terminator, so nothing is sliced off the end (slicing would truncate the\n",
    "# last action of a final line that has no trailing newline). Blank lines are skipped.\n",
    "splited_data = [\n",
    "    [nltk.pos_tag(sentence.split()), transitions.split()]\n",
    "    for sentence, transitions in (line.split('|||')[:2] for line in data if line.strip())\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build Vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Split the examples into tagged sentences and transition sequences, then\n",
    "# split the flattened (word, tag) pairs into words and tags.\n",
    "train_x, train_y = zip(*splited_data)\n",
    "train_x_f = flatten(train_x)\n",
    "sents, pos_tags = zip(*train_x_f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sort the tag set first: the iteration order of a set of strings is not\n",
    "# stable from one interpreter run to the next, which would reshuffle the indices.\n",
    "tag2index = {v:i for i,v in enumerate(sorted(set(pos_tags)))}\n",
    "tag2index['<root>'] = len(tag2index)\n",
    "tag2index['<NULL>'] = len(tag2index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "vocab = [v.split('\\t')[0] for v in vocab]\n",
    "# setdefault hands out the next free index only to unseen entries, so the indices\n",
    "# stay dense and unique even if the vocabulary repeats a word or already lists\n",
    "# one of the two special symbols.\n",
    "word2index = {}\n",
    "for word in vocab + ['ROOT', '<NULL>']:\n",
    "    word2index.setdefault(word, len(word2index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "actions = ['SHIFT', 'REDUCE_L', 'REDUCE_R']\n",
    "action2index = {v:i for i, v in enumerate(actions)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_data = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Rebuilt from scratch so that re-running this cell does not append duplicates.\n",
    "train_data = []\n",
    "for tx, ty in splited_data:\n",
    "    state = TransitionState(tx)\n",
    "    for action in ty + ['REDUCE_R']: # root\n",
    "        # The features describe the configuration before the action is applied.\n",
    "        feat = get_feat(state, word2index, tag2index)\n",
    "        actionTensor = Variable(LongTensor([action2index[action]])).view(1, -1)\n",
    "        train_data.append([feat, actionTensor])\n",
    "        if action == 'SHIFT':\n",
    "            state.shift()\n",
    "        elif action == 'REDUCE_R':\n",
    "            state.right_arc()\n",
    "        elif action == 'REDUCE_L':\n",
    "            state.left_arc()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(Variable containing:\n",
       "   9151  9152  9152  2106     2   353  9152  9152  9152  9152\n",
       "  [torch.cuda.LongTensor of size 1x10 (GPU 0)], Variable containing:\n",
       "     43    44    44    25    22    26    44    44    44    44\n",
       "  [torch.cuda.LongTensor of size 1x10 (GPU 0)]), Variable containing:\n",
       "  0\n",
       " [torch.cuda.LongTensor of size 1x1 (GPU 0)]]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modeling "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/05.neural-dparser-architecture.png\">\n",
    "<center>borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture6.pdf</center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class NeuralDependencyParser(nn.Module):\n",
    "    \n",
    "    def __init__(self, w_size, w_embed_dim, t_size, t_embed_dim, hidden_size, target_size):\n",
    "        \n",
    "        super(NeuralDependencyParser, self).__init__()\n",
    "        \n",
    "        self.w_embed =  nn.Embedding(w_size, w_embed_dim)\n",
    "        self.t_embed = nn.Embedding(t_size, t_embed_dim)\n",
    "        self.hidden_size = hidden_size\n",
    "        self.target_size = target_size\n",
    "        self.linear = nn.Linear((w_embed_dim + t_embed_dim) * 10, self.hidden_size)\n",
    "        self.out = nn.Linear(self.hidden_size, self.target_size)\n",
    "        \n",
    "        self.w_embed.weight.data.uniform_(-0.01, 0.01) # init\n",
    "        self.t_embed.weight.data.uniform_(-0.01, 0.01) # init\n",
    "        \n",
    "    def forward(self, words, tags):\n",
    "        \n",
    "        wem = self.w_embed(words).view(words.size(0), -1)\n",
    "        tem = self.t_embed(tags).view(tags.size(0), -1)\n",
    "        inputs = torch.cat([wem, tem], 1)\n",
    "        h1 = torch.pow(self.linear(inputs), 3) # cube activation function\n",
    "        preds = -self.out(h1)\n",
    "        return F.log_softmax(preds,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "STEP = 5  # NOTE(review): not referenced by any visible cell; the training loop makes a single pass over the data - confirm intent\n",
    "BATCH_SIZE = 256  # examples per mini-batch\n",
    "W_EMBED_SIZE = 50  # word embedding dimension\n",
    "T_EMBED_SIZE = 10  # POS-tag embedding dimension\n",
    "HIDDEN_SIZE = 512  # width of the hidden layer\n",
    "LR = 0.001  # learning rate passed to Adam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Vocabulary sizes fix the embedding tables; the action inventory fixes the output layer.\n",
    "model = NeuralDependencyParser(\n",
    "    w_size=len(word2index),\n",
    "    w_embed_dim=W_EMBED_SIZE,\n",
    "    t_size=len(tag2index),\n",
    "    t_embed_dim=T_EMBED_SIZE,\n",
    "    hidden_size=HIDDEN_SIZE,\n",
    "    target_size=len(action2index),\n",
    ")\n",
    "if USE_CUDA:\n",
    "    model = model.cuda()\n",
    "\n",
    "# NLLLoss pairs with the log-probabilities returned by the model.\n",
    "optimizer = optim.Adam(model.parameters(), lr=LR)\n",
    "loss_function = nn.NLLLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "losses = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean_loss : 1.10\n",
      "mean_loss : 0.82\n",
      "mean_loss : 0.38\n",
      "mean_loss : 0.32\n",
      "mean_loss : 0.29\n",
      "mean_loss : 0.27\n",
      "mean_loss : 0.25\n",
      "mean_loss : 0.24\n",
      "mean_loss : 0.23\n",
      "mean_loss : 0.22\n",
      "mean_loss : 0.22\n",
      "mean_loss : 0.21\n",
      "mean_loss : 0.20\n",
      "mean_loss : 0.20\n",
      "mean_loss : 0.19\n",
      "mean_loss : 0.19\n"
     ]
    }
   ],
   "source": [
    "# Start from an empty window so that a re-run does not average in stale losses.\n",
    "losses = []\n",
    "for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
    "    \n",
    "    model.zero_grad()\n",
    "    inputs, targets = list(zip(*batch))\n",
    "    words, tags = list(zip(*inputs))\n",
    "    words = torch.cat(words)\n",
    "    tags = torch.cat(tags)\n",
    "    targets = torch.cat(targets)\n",
    "    preds = model(words, tags)\n",
    "    loss = loss_function(preds, targets.view(-1))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    # view(-1) yields a one-element tensor whether the loss is 0-dimensional\n",
    "    # (PyTorch >= 0.4) or already has one element, so the [0] lookup works on both.\n",
    "    losses.append(loss.data.view(-1).tolist()[0])\n",
    "    \n",
    "    if i % 100 == 0:\n",
    "        print(\"mean_loss : %0.2f\" %(np.mean(losses)))\n",
    "        losses = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test (UAS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# The context manager closes the file as soon as its lines have been read.\n",
    "with open('../dataset/dparser/dev.txt','r') as dev_file:\n",
    "    dev = dev_file.readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Same '<sentence> ||| <transitions>' layout as the training file. str.split()\n",
    "# already drops the line terminator, so nothing is sliced off the end, and\n",
    "# blank lines are skipped.\n",
    "splited_data = [\n",
    "    [nltk.pos_tag(sentence.split()), transitions.split()]\n",
    "    for sentence, transitions in (line.split('|||')[:2] for line in dev if line.strip())\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dev_data = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Rebuilt from scratch so that re-running this cell does not append duplicates.\n",
    "dev_data = []\n",
    "for tx,ty in splited_data:\n",
    "    state = TransitionState(tx)\n",
    "    for action in ty + ['REDUCE_R']: # root\n",
    "        # The features describe the configuration before the action is applied.\n",
    "        feat = get_feat(state, word2index, tag2index)\n",
    "        dev_data.append([feat, action2index[action]])\n",
    "        if action == 'SHIFT':\n",
    "            state.shift()\n",
    "        elif action == 'REDUCE_R':\n",
    "            state.right_arc()\n",
    "        elif action == 'REDUCE_L':\n",
    "            state.left_arc()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "accuracy = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "91.6377171215881\n"
     ]
    }
   ],
   "source": [
    "for dev in dev_data:\n",
    "    input, target = dev[0], dev[1]\n",
    "    word, tag = input[0], input[1]\n",
    "    pred = model(word, tag).max(1)[1]\n",
    "    pred = pred.data.tolist()[0]\n",
    "    if pred == target:\n",
    "        accuracy += 1\n",
    "\n",
    "print(accuracy/len(dev_data) * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting parsed result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test = TransitionState(nltk.pos_tag(\"I shot an elephant in my pajamas\".split()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "index2action = {i:v for v, i in action2index.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "while test.is_done() == False:\n",
    "    feat = get_feat(test, word2index, tag2index)\n",
    "    word, tag = feat[0], feat[1]\n",
    "    action = model(word, tag).max(1)[1].data.tolist()[0]\n",
    "    \n",
    "    action = index2action[action]\n",
    "    \n",
    "    if action == 'SHIFT':\n",
    "        test.shift()\n",
    "    elif action == 'REDUCE_R':\n",
    "        test.right_arc()\n",
    "    elif action == 'REDUCE_L':\n",
    "        test.left_arc()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stack : ['ROOT'] \n",
      "buffer : []\n"
     ]
    }
   ],
   "source": [
    "print(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'(ROOT (shot I (elephant an) (in (pajamas my))))'"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.to_tree_string()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEMCAMAAAAf2ZYHAAAJJGlDQ1BpY2MAAHjalZVnUJNZF8fv\n8zzphUASQodQQ5EqJYCUEFoo0quoQOidUEVsiLgCK4qINEUQUUDBVSmyVkSxsCgoYkE3yCKgrBtX\nERWUF/Sd0Xnf2Q/7n7n3/OY/Z+4995wPFwCCOFgSvLQnJqULvJ3smIFBwUzwg8L4aSkcT0838I96\nPwyg5XhvBfj3IkREpvGX4sLSyuWnCNIBgLKXWDMrPWWZDy8xPTz+K59dZsFSgUt8Y5mjv/Ho15xv\nLPqa4+vNXXoVCgAcKfoHDv+B/3vvslQ4gvTYqMhspk9yVHpWmCCSmbbcCR6Xy/QUJEfFJkT+UPC/\nSv4HpUdmpy9HbnLKBkFsdEw68/8ONTIwNATfZ/HW62uPIUb//85nWd+95HoA2LMAIHu+e+GVAHTu\nAED68XdPbamvlHwAOu7wMwSZ3zzU8oYGBEABdCADFIEq0AS6wAiYAUtgCxyAC/AAviAIrAN8EAMS\ngQBkgVywDRSAIrAH7AdVoBY0gCbQCk6DTnAeXAHXwW1wFwyDJ0AIJsArIALvwTwEQViIDNEgGUgJ\nUod0ICOIDVlDDpAb5A0FQaFQNJQEZUC50HaoCCqFqqA6qAn6BToHXYFuQoPQI2gMmob+hj7BCEyC\n6bACrAHrw2yYA7vCvvBaOBpOhXPgfHg3XAHXwyfgDvgKfBsehoXwK3gWAQgRYSDKiC7CRriIBxKM\nRCECZDNSiJQj9Ugr0o30IfcQITKDfERhUDQUE6WLskQ5o/xQfFQqajOqGFWFOo7qQPWi7qHGUCLU\nFzQZLY/WQVugeehAdDQ6C12ALkc3otvR19DD6An0ewwGw8CwMGYYZ0wQJg6zEVOMOYhpw1zGDGLG\nMbNYLFYGq4O1wnpgw7Dp2AJsJfYE9hJ2CDuB/YAj4pRwRjhHXDAuCZeHK8c14y7ihnCTuHm8OF4d\nb4H3wEfgN+BL8A34bvwd/AR+niBBYBGsCL6EOMI2QgWhlXCNMEp4SyQSVYjmRC9iLHErsYJ4iniD\nOEb8SKKStElcUggpg7SbdIx0mfSI9JZMJmuQbcnB5HTybnIT+Sr5GfmDGE1MT4wnFiG2RaxarENs\nSOw1BU9Rp3Ao6yg5lHLKGcodyow4XlxDnCseJr5ZvFr8nPiI+KwETcJQwkMiUaJYolnipsQUFUvV\noDpQI6j51CPUq9RxGkJTpXFpfNp2WgPtGm2CjqGz6Dx6HL2IfpI+QBdJUiWNJf0lsyWrJS9IChkI\nQ4PBYyQwShinGQ8Yn6QUpDhSkVK7pFqlhqTmpOWkbaUjpQul26SHpT/JMGUcZOJl9sp0yjyVRclq\ny3rJZskekr0mOyNHl7OU48sVyp2WeywPy2vLe8tvlD8i3y8/q6Co4KSQolCpcFVhRpGhaKsYp1im\neFFxWommZK0Uq1SmdEnpJVOSyWEmMCuYvUyRsryys3KGcp3ygPK8CkvFTyVPpU3lqSpBla0apVqm\n2qMqUlNSc1fLVWtRe6yOV2erx6gfUO9Tn9NgaQRo7NTo1JhiSbN4rBxWC2tUk6xpo5mqWa95Xwuj\nxdaK1zqodVcb1jbRjtGu1r6jA+uY6sTqHNQZXIFeYb4iaUX9ihFdki5HN1O3RXdMj6Hnppen16n3\nWl9NP1h/r36f/hcDE4MEgwaDJ4ZUQxfDPMNuw7+NtI34RtVG91eSVzqu3LKya+UbYx3jSONDxg9N\naCbuJjtNekw+m5qZCkxbTafN1MxCzWrMRth0tie7mH3DHG1uZ77F/Lz5RwtTi3SL0xZ/Wepaxls2\nW06tYq2KXNWwatxKxSrMqs5KaM20DrU+bC20UbYJs6m3eW6rahth22g7ydHixHFOcF7bGdgJ7Nrt\n5rgW3E3cy/aIvZN9of2AA9XBz6HK4Z
mjimO0Y4ujyMnEaaPTZWe0s6vzXucRngKPz2viiVzMXDa5\n9LqSXH1cq1yfu2m7Cdy63WF3F/d97qOr1Vcnre70AB48j30eTz1Znqmev3phvDy9qr1eeBt653r3\n+dB81vs0+7z3tfMt8X3ip+mX4dfjT/EP8W/ynwuwDygNEAbqB24KvB0kGxQb1BWMDfYPbgyeXeOw\nZv+aiRCTkIKQB2tZa7PX3lwnuy5h3YX1lPVh68+EokMDQptDF8I8wurDZsN54TXhIj6Xf4D/KsI2\noixiOtIqsjRyMsoqqjRqKtoqel/0dIxNTHnMTCw3tir2TZxzXG3cXLxH/LH4xYSAhLZEXGJo4rkk\nalJ8Um+yYnJ28mCKTkpBijDVInV/qkjgKmhMg9LWpnWl05c+xf4MzYwdGWOZ1pnVmR+y/LPOZEtk\nJ2X3b9DesGvDZI5jztGNqI38jT25yrnbcsc2cTbVbYY2h2/u2aK6JX/LxFanrce3EbbFb/stzyCv\nNO/d9oDt3fkK+Vvzx3c47WgpECsQFIzstNxZ+xPqp9ifBnat3FW560thROGtIoOi8qKFYn7xrZ8N\nf674eXF31O6BEtOSQ3swe5L2PNhrs/d4qURpTun4Pvd9HWXMssKyd/vX779Zblxee4BwIOOAsMKt\noqtSrXJP5UJVTNVwtV11W418za6auYMRB4cO2R5qrVWoLar9dDj28MM6p7qOeo368iOYI5lHXjT4\nN/QdZR9tapRtLGr8fCzpmPC49/HeJrOmpmb55pIWuCWjZfpEyIm7J+1PdrXqtta1MdqKToFTGade\n/hL6y4PTrqd7zrDPtJ5VP1vTTmsv7IA6NnSIOmM6hV1BXYPnXM71dFt2t/+q9+ux88rnqy9IXii5\nSLiYf3HxUs6l2cspl2euRF8Z71nf8+Rq4NX7vV69A9dcr9247nj9ah+n79INqxvnb1rcPHeLfavz\ntuntjn6T/vbfTH5rHzAd6Lhjdqfrrvnd7sFVgxeHbIau3LO/d/0+7/7t4dXDgw/8HjwcCRkRPox4\nOPUo4dGbx5mP559sHUWPFj4Vf1r+TP5Z/e9av7cJTYUXxuzH+p/7PH8yzh9/9UfaHwsT+S/IL8on\nlSabpoymzk87Tt99ueblxKuUV/MzBX9K/FnzWvP12b9s/+oXBYom3gjeLP5d/Fbm7bF3xu96Zj1n\nn71PfD8/V/hB5sPxj+yPfZ8CPk3OZy1gFyo+a33u/uL6ZXQxcXHxPy6ikLxyKdSVAAAAIGNIUk0A\nAHomAACAhAAA+gAAAIDoAAB1MAAA6mAAADqYAAAXcJy6UTwAAACZUExURf///wBZkABZkABZkABZ\nkABZkABZkABZkABZkABZkABZkABZkABZkABZkABZkBdSUhdTUxdSUhdSUhdSUhdSUhdSUhdSUhdS\nUhdSUhdSUhdSUhdSUhdSUhdSUhdSUhdSUj+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+P\nVz+PVz+PVwBakj+PVwBZkBdSUj+PV////zBJQHUAAAAvdFJOUwDumRGI3cxEuyIzZlV3qlWHEWaq\n7vT33cyIRCIzd5m7IkSI3aozu+53ZhGZzDBVAXcCsQAAAAFiS0dEAIgFHUgAAAAJcEhZcwAAAEgA\nAABIAEbJaz4AAAAHdElNRQfhCwILJR1LtcjCAAAGuUlEQVR42u3deUOjOBjH8VDO0lrX6q73WY/R\nzhr6/t/cJuEo1NKpYn1W+X7+qJRB55kfMQQnEaUAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAAAA4FvSjjfwzbYfhFqHgd1UQVRsxroQSJf6A2kdJUliXpUaelrbTW+o/Mjud5tp/sdJMpIu9QfS\nOjavqf2Q6GhozkGkExVoL1X2HCS1g/DpqvRTv8h4rLWJ3fUzptfxlwfh0y17HhN1uSsu4y4+kv6O\nFJ
dUc9Ul/a/ngo30nunltR7bPbWeZ0zPs1su2NgObtSkvOoO7FU3dlfdwfIgfL482IEd3PhhPuKM\nfPvd4EacoV87CJ8uD9b3XN9j77ai/G5rL6puvEgfAIDdyaQL6LVMuoBey6QL6LVMuoBey6QL6LVM\nuoBey6QL6LVMuoBey6QL6LVMuoBey6QL6LVMuoBey6QL6LVMuoBey6QLAAB8kSSRrqDPyrmcS3uc\njy/zNv1Qf+Tr4B2GE609O3PNvHo6SZWKJ3YuW75uiIlsuxXqvTiw85a15wWRntjUB+Y8BMNAmxfp\n8n44N3k89vMZm0PT/bgp5G7piqbn2bWBLqbPuqzNS97d2FfS373RxE4fJ30R/tCtnFum73qeMT3P\nV/A9bxzv1dt+edW1b0ZcdXcrnRTr1cv03YgzsgukA0ackmj5AL7C/l/7B9I19NPB9PDo73+Os5PT\nM+lS+uZ8epgdn54dT9XF5WF2dHh5Ll1Rb5xfnpgWf6HU9DjfcXZ6Ys4FndDuXZikT4qmbpp+6WB6\ndZRd21OCXTk7Pc4Op2UvUzb90sXldXZ0NaUT2oX9q6Ojw2mte6k1/eVBp/Y6vC9d689iBzhHV81Q\nV5t+yVyRTSd0SSf0OYoBzurudU2/dHF6nR1fTbkOd1QNcFa1Nf3Swf4VNwOd1Ac4qzY1/dL5pemE\nuBn4iOYAZ9Wfmn7ty5ibgStuBt7jzQBn1TZNv2RuBo7Xd194Y90AZ9XWTb+U/0SCm4HNWgY4q97T\n9Cv7/ERio/OT7XqIdzf96i9wP5HgO2C9bccnXfK7uJT+VwL4X7m5fbtv7TqJDy+eYM5Vu7vF231r\n89oQYrzxxAQ8maLVZ6Q/oHV/UD39INSebag26HHkfve7miRp5PZq7Sc6TPPD9MQ3LT5JQ7sn0Vpv\naP128mF5KJpq6Qd6UKyTsJM1o3iUz94MA7tMwj4LIXCP+xjpsJzYme8ZmT/e8OSVfAZu8cloqKXv\n6XEca8/lNdFBHEc6zRdPBDqqFk+oNB6WkRZ7Nl9XG4eiYZl+WjxrwuXl5dvFHP3YTdtXeZJ2IZeu\nTynfKn3F4GeNRvrFlks/rrZX0k+Kzon0u2v0PCZxe2XMe55yu+p51LIXiUn/UzSuul4w0RMX09hc\nSQP7fIniqjtaBhiaK2hoT0ot/b0N8/hJv92bEae/HHFO3OW1GnGq/MWNHc2ZWu4xZ2fS/jeQ/sdt\nFZnvS5f5Q9FgJZE++u7sustnn3b6bHRLv9tng/Qlkb6kjuln0vV/b6QvifQlkb6kbumfZ9L1f28d\nRy2ZdP3fG+lL6po+k5i76Jo+a+m6IH1JpA8A+AP+S1gSyybwzd3NFov7B6Xubx8XT79aD0uTZBzq\nxK9WV7iexx947vcI5+sngmIBRprkR1TPtUCL58Wzepk/KrX4faN+LW5aD7QrJOwvx66trrCLiqLy\nzcREPQkiHbmVAm5n9VwLtDHNXt0uTPrPZmNx23qcndu8p8P66gql4jgt36R+uYbCX+4snmuBFi9z\n0/MsFkXwG9PPJ/o3Vlek9TfVXNpRmO+snmuBFo+/H4q2v2X69dUVvmdb92r6pv2n+ZvyuRZo4fK+\n2yr91PU89dUVcXlKGunbHW6lTPVcC7SYzdXLr6fFyxbpJ+6qW19dMbQX2rBc/1Vr+4Mo1ON/q+da\noMXDbPE0v5k9bpG+GXGawWVjdcXIM4NQ+5SuZr/v6UHq6b3quRbo6k0PQpfyhVbDDkj/C62GrbmN\nAlg9IYoZ5JJIXxLpS2IOsyTSl0T6kkhfEqsnRGXSBfRaJl1Ar7F6QhIzyCWRviTSBwAAwE64FRfP\nD7PF7PbVTf2/+y1dUo/kKy4eb9R8pmZ3ZsfsWbqkHslXXDy7Keh3M6VeFy/SJfVINfPWpH9jPs7n\n0hX1ST19dT9XT7edvyS21kj/9el5Jl1QrzTSV7OnO+mCeqWZ/qb1
pti1+b10BT32uniQLqG/zE2v\ndAkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAGv8B+d3Lyt1lXwEAAAAldEVY\ndGRhdGU6Y3JlYXRlADIwMTctMTEtMDJUMjA6Mzc6MjkrMDk6MDAaYEsDAAAAJXRFWHRkYXRlOm1v\nZGlmeQAyMDE3LTExLTAyVDIwOjM3OjI5KzA5OjAwaz3zvwAAACN0RVh0cHM6SGlSZXNCb3VuZGlu\nZ0JveAAzODF4MjY4LTE5MC0xMzNmTR0tAAAAHHRFWHRwczpMZXZlbABBZG9iZS0zLjAgRVBTRi0z\nLjAKm3C74wAAACJ0RVh0cHM6U3BvdENvbG9yLTAAZm9udCBMaWJlcmF0aW9uU2Fuc/4Zp8YAAAAA\nSUVORK5CYII=\n",
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "draw_nltk_tree(Tree.fromstring(test.to_tree_string()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Further Topics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* <a href=\"https://arxiv.org/pdf/1506.06158.pdf\">Structured Training for Neural Network Transition-Based Parsing</a>\n",
    "* <a href=\"https://arxiv.org/pdf/1703.04474\">DRAGNN: A Transition-based Framework for Dynamically Connected Neural Networks</a>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
